#include <Columns/ColumnArray.h>
#include <Columns/ColumnFixedString.h>
#include <Columns/ColumnQBit.h>

#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeFixedString.h>
#include <DataTypes/DataTypeQBit.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/Serializations/SerializationQBit.h>

#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>

#include <Interpreters/Context.h>
#include <Interpreters/castColumn.h>

/// Include immintrin. Otherwise `simsimd` fails to build: `unknown type name '__bfloat16'`
#if USE_SIMSIMD
#    if defined(__x86_64__) || defined(__i386__)
#        include <immintrin.h>
#    endif
#    include <simsimd/simsimd.h>
#endif


namespace DB
{
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION;
extern const int TOO_MANY_ARGUMENTS_FOR_FUNCTION;
}

/// Kernel tag consumed by FunctionArrayDistance: provides the user-visible function name
/// (and a slot for constant parameters) of the transposed L2 distance over QBit bit planes.
struct L2DistanceTransposed
{
    static constexpr auto name = "L2DistanceTransposed";
    struct ConstParams
    {
        /// NOTE(review): `groups` is not referenced anywhere in this file — presumably part of
        /// a shared Kernel/ConstParams interface used by other distance kernels; confirm before removing.
        UInt8 groups;
    };
};

/** L2DistanceTransposed has two calling conventions:
  * 1. User-facing (documented): L2DistanceTransposed(qbit, ref_vec, precision)
  * 2. Internal (undocumented): L2DistanceTransposed(vec.1, ..., vec.precision, qbit_size, ref_vec)
  *
  * The second form is generated by L2DistanceTransposedPartialReadsPass for partial column reads.
  * It is not exposed in documentation and users should not call it directly.
  *
  * IMPORTANT: In the second form, ref_vec type must match the original QBit element type
  * (BFloat16/Float32/Float64). This is the only way to determine the QBit type since we
  * only receive individual bit planes. Type mismatches will produce incorrect results.
  */

template <typename Kernel>
class FunctionArrayDistance : public IFunction
{
public:
    String getName() const override { return Kernel::name; }
    static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionArrayDistance<Kernel>>(); }
    bool isVariadic() const override { return true; }
    size_t getNumberOfArguments() const override { return 0; }
    /// The constant precision/size argument sits at a different position in each calling convention,
    /// so it cannot be declared here; constness is validated in getReturnTypeImpl / validateOptimizedArguments.
    ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {}; }
    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
    bool useDefaultImplementationForConstants() const override { return true; }

    /// Computes the L2 (Euclidean) distance between two vectors `x` and `y` of `array_size`
    /// elements and writes it to `*result`. T must be BFloat16, Float32 or Float64.
    template <typename T>
    static void l2Distance(const T * __restrict x, const T * __restrict y, std::size_t array_size, Float64 * result)
    {
        static_assert(
            std::is_same_v<T, BFloat16> || std::is_same_v<T, Float32> || std::is_same_v<T, Float64>,
            "l2Distance supports only BFloat16, Float32 and Float64");

        /// Benchmarks show simsimd has great performance. We do not need CPU dispatch because SimSimd provides its own dynamic dispatch
#if USE_SIMSIMD
        if constexpr (std::is_same_v<T, BFloat16>)
            simsimd_l2_bf16(reinterpret_cast<const simsimd_bf16_t *>(x), reinterpret_cast<const simsimd_bf16_t *>(y), array_size, result);
        else if constexpr (std::is_same_v<T, Float32>)
            simsimd_l2_f32(x, y, array_size, result);
        else if constexpr (std::is_same_v<T, Float64>)
            simsimd_l2_f64(x, y, array_size, result);
        return;
#endif

        /// Fallback to scalar implementation if simsimd is not available. It also originates from simsimd, but is decoupled
        if constexpr (std::is_same_v<T, BFloat16>)
            l2DistanceScalar<BFloat16, Float32>(x, y, array_size, result);
        else if constexpr (std::is_same_v<T, Float32>)
            l2DistanceScalar<Float32, Float32>(x, y, array_size, result);
        else if constexpr (std::is_same_v<T, Float64>)
            l2DistanceScalar<Float64, Float64>(x, y, array_size, result);
    }

    /// Scalar fallback: accumulates squared differences in AccumulatorType (wider than InputType
    /// for BFloat16 to limit rounding error) and writes sqrt of the sum to `*result`.
    template <typename InputType, typename AccumulatorType>
    static void l2DistanceScalar(const InputType * __restrict x, const InputType * __restrict y, std::size_t array_size, Float64 * result)
    {
        /// This could be vectorized, but we consider this a fallback code path, so no need to optimize it heavily
        AccumulatorType d2 = 0;
        for (size_t i = 0; i != array_size; ++i)
        {
            AccumulatorType xi = static_cast<AccumulatorType>(*(x + i));
            AccumulatorType yi = static_cast<AccumulatorType>(*(y + i));
            d2 += (xi - yi) * (xi - yi);
        }
        *result = static_cast<Float64>(sqrt(d2));
    }

    /// Validates either calling convention (see the file-level comment) and returns Float64.
    /// For the user-facing form, checks: arg0 is QBit, arg1 is Array, arg2 is a constant UInt8
    /// precision in [1, element_size].
    DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
    {
        if (arguments.size() < 3)
            throw Exception(
                ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION,
                "Number of arguments for function {} can't be {}, should be at least 3",
                getName(),
                arguments.size());

        /// Check if we are in optimised L2DistanceTransposed(vec.1, ..., vec.p, qbit_size, ref_vec) case. If something goes wrong, we
        /// fallback to the original L2DistanceTransposed(qbit, ref_vec, p) handling. The arguments in optimised case are generated by us
        /// and are almost certainly correct. It is extremely unlikely that user will write optimised case manually. Thus, any error in
        /// arguments is treated as user error from the original case.
        if (validateOptimizedArguments(arguments))
            return std::make_shared<DataTypeFloat64>();

        if (arguments.size() > 3)
            throw Exception(
                ErrorCodes::TOO_MANY_ARGUMENTS_FOR_FUNCTION,
                "Number of arguments for function {} is {}. Expected 3",
                getName(),
                arguments.size());

        /// Check the first two arguments
        const auto * zeroth_arg_type = checkAndGetDataType<DataTypeQBit>(arguments[0].type.get());
        const auto * first_arg_type = checkAndGetDataType<DataTypeArray>(arguments[1].type.get());

        if (!zeroth_arg_type)
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument of function {} must be a QBit", getName());

        if (!first_arg_type)
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Second argument of function {} must be an Array", getName());

        /// Check that precision (third argument) is valid
        const auto & precision_col = arguments[2];
        WhichDataType which(precision_col.type);

        if (!which.isUInt8())
            throw Exception(
                ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                "The third argument of function {} must be a UInt8 constant, got {}",
                getName(),
                precision_col.type->getName());

        if (!(precision_col.column && precision_col.column->isConst()))
            throw Exception(
                ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                "The third argument of function {} must be a UInt8 constant, got {} (not constant)",
                getName(),
                precision_col.type->getName());

        const auto precision = precision_col.column->getUInt(0);
        const auto element_size = zeroth_arg_type->getElementSize();

        if (precision == 0 || precision > element_size)
            throw Exception(
                ErrorCodes::BAD_ARGUMENTS,
                "The third argument (precision) of function {} must be in range [1, {}] for {} QBit, got {}",
                getName(),
                element_size,
                zeroth_arg_type->getElementType()->getName(),
                precision);

        return std::make_shared<DataTypeFloat64>();
    }

    /// Validates arguments for optimised L2DistanceTransposed(vec.1, ..., vec.p, qbit_size, ref_vec) case.
    /// Returns false (so the caller falls back to user-facing validation) on any mismatch:
    /// last arg not a float Array, penultimate arg not a constant UInt, any bit plane not a
    /// FixedString of the expected size, or more bit planes than the reference type has bits.
    bool validateOptimizedArguments(const ColumnsWithTypeAndName & arguments) const
    {
        constexpr size_t max_precision = 64;
        const size_t precision = arguments.size() - 2;
        const auto * ref_vec_type = checkAndGetDataType<DataTypeArray>(arguments.back().type.get());
        const auto & qbit_size_column = (arguments.end() - 2)->column;
        const auto & qbit_size_arg_type = (arguments.end() - 2)->type;
        const WhichDataType which_qbit_size_arg_type(qbit_size_arg_type);

        /// Note: we only allow constant qbit_size_column
        if (!ref_vec_type || precision > max_precision || !qbit_size_column || !qbit_size_column->isConst()
            || !which_qbit_size_arg_type.isUInt())
            return false;

        const auto ref_vec_type_id = ref_vec_type->getNestedType()->getTypeId();
        if (ref_vec_type_id != TypeIndex::BFloat16 && ref_vec_type_id != TypeIndex::Float32 && ref_vec_type_id != TypeIndex::Float64)
            return false;

        const auto qbit_size = qbit_size_column->getUInt(0);
        const auto qbit_size_bytes = DataTypeQBit::bitsToBytes(qbit_size);

        /// All QBit subcolumns should be FixedString and have a consistent size
        for (size_t i = 0; i < precision; ++i)
        {
            const auto * arg_type = checkAndGetDataType<DataTypeFixedString>(arguments[i].type.get());

            if (!arg_type || arg_type->getN() != qbit_size_bytes)
                return false;
        }

        /// The type of reference vector dictates what type QBit had before we sliced it into q.1, ..., q.precision.
        /// Check that the number of bit planes doesn't exceed the maximum precision for the reference vector type.
        size_t max_precision_for_type = ref_vec_type->getNestedType()->getSizeOfValueInMemory() * 8;

        if (precision > max_precision_for_type)
            return false;

        return true;
    }


    ColumnPtr
    executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & /* result_type */, size_t input_rows_count) const override
    {
        const auto & last_arg = arguments.back();

        /// If last argument is UInt, we are in L2DistanceTransposed(qbit, ref_vec, p) case
        WhichDataType which_last(last_arg.type);
        if (which_last.isUInt8())
            return executeWithQBitColumnConverted(arguments, input_rows_count);

        /// Otherwise, L2DistanceTransposed(vec.1, ..., vec.p, qbit_size, ref_vec)

        /// First, check that the reference vector sizes match qbit size
        const ColumnArray & reference_vector = *assert_cast<const ColumnArray *>(extractFromConst(arguments.back().column).get());
        const auto qbit_size = (arguments.end() - 2)->column->getUInt(0);
        const auto & offsets = reference_vector.getOffsets();

        /// In dry run, the offsets can be empty with non-constant reference vector
        if (!offsets.empty())
        {
            /// Reading offsets[i - 1] at i == 0 relies on ClickHouse's padded offsets arrays,
            /// where the element before the start is a valid zero.
            for (size_t i = 0; i < reference_vector.size(); ++i)
            {
                if (offsets[i] - offsets[i - 1] != qbit_size)
                    throw Exception(
                        ErrorCodes::BAD_ARGUMENTS,
                        "The reference vector in the last argument of function {} has wrong size. Got: {}, expected: {}",
                        getName(),
                        offsets[i] - offsets[i - 1],
                        qbit_size);
            }
        }

        /// Continue with execution
        auto type_y = checkAndGetDataType<DataTypeArray>(last_arg.type.get())->getNestedType()->getTypeId();
        const size_t precision = arguments.size() - 2;

        /// We need to find two types: the type of the reference vector and the type of the calculation.
        /// The type of calculation is determined by the value of `precision`. For example, if col_x is Float32 and p = 16, we will only have
        /// 16 meaningful bits to calculate the distance. So we can downcast the reference vector to BFloat16 and do calculations faster.
        auto dispatch_by_accum_type = [&]<typename RefT>(auto func)
        {
            auto calc_type
                = (precision <= 16 ? TypeToTypeIndex<BFloat16> : (precision <= 32 ? TypeToTypeIndex<Float32> : TypeToTypeIndex<Float64>));

            /// Float64 cannot be downcasted to Float32 or BFloat16 in an easy way by reordering bits. That is why with it we always do
            /// calculations in full width. Alternatively, we could static_cast each element when calculating, but it is slower.
            if (std::is_same_v<RefT, Float64>)
                return func.template operator()<RefT, Float64>();
            else if (calc_type == TypeToTypeIndex<Float32>)
                return func.template operator()<RefT, Float32>();
            else if (calc_type == TypeToTypeIndex<BFloat16>)
                return func.template operator()<RefT, BFloat16>();
            else
                UNREACHABLE();
        };

        auto execute_with_type = [&]<typename T>() -> ColumnPtr
        {
            return dispatch_by_accum_type.template operator()<T>(
                [&]<typename RefT, typename CalcT>()
                { return executeDistanceCalculation<RefT, CalcT>(reference_vector, arguments, qbit_size, input_rows_count); });
        };

        /// Dispatch to type-specific implementation based on reference vector type
        switch (type_y)
        {
            case TypeIndex::BFloat16:
                return execute_with_type.template operator()<BFloat16>();
            case TypeIndex::Float32:
                return execute_with_type.template operator()<Float32>();
            case TypeIndex::Float64:
                return execute_with_type.template operator()<Float64>();
            default:
                UNREACHABLE();
        }
    }


private:
    /// Unwraps a ColumnConst into its nested data column; returns the column unchanged otherwise.
    /// The returned pointer shares ownership with (or is kept alive by) the argument column.
    static ColumnPtr extractFromConst(const ColumnPtr & column)
    {
        return column->isConst() ? assert_cast<const ColumnConst *>(column.get())->getDataColumnPtr() : column;
    }

    /// L2DistanceTransposed(qbit, ref_vec, p) case. Convert arguments to [qbit.1, ..., qbit.p, ref_vec] format before executing
    ColumnPtr executeWithQBitColumnConverted(const ColumnsWithTypeAndName & arguments, size_t input_rows_count) const
    {
        ColumnsWithTypeAndName converted_arguments;

        const auto precision = arguments[2].column->getUInt(0);
        const auto * qbit_type = assert_cast<const DataTypeQBit *>(arguments[0].type.get());
        const auto * qbit_ptr = assert_cast<const ColumnQBit *>(extractFromConst(arguments[0].column).get());
        const auto qbit_dimension = qbit_type->getDimension();
        const auto & qbit_tuple = assert_cast<const ColumnTuple &>(qbit_ptr->getTupleColumn());
        const auto bit_plane_type = qbit_type->getNestedTupleElementType();

        for (size_t bit = 0; bit < precision; ++bit)
            converted_arguments.emplace_back(qbit_tuple.getColumn(bit).getPtr(), bit_plane_type, toString(bit + 1));
        /// Add dimension as penultimate argument and reference vector as last argument
        auto dimension_column = DataTypeUInt64().createColumnConst(1, qbit_dimension);
        converted_arguments.emplace_back(dimension_column, std::make_shared<DataTypeUInt64>(), "dimension");

        /// Cast reference vector to match QBit element type to ensure correct dispatch
        auto ref_vec_type = arguments[1].type;
        auto expected_ref_vec_type = std::make_shared<DataTypeArray>(qbit_type->getElementType());

        if (ref_vec_type->equals(*expected_ref_vec_type))
        {
            converted_arguments.emplace_back(arguments[1]);
        }
        else
        {
            auto casted_column = castColumn(arguments[1], expected_ref_vec_type);
            converted_arguments.emplace_back(casted_column, expected_ref_vec_type, arguments[1].name);
        }

        /// We go back to the function that called us, but now with converted arguments
        return executeImpl(converted_arguments, nullptr, input_rows_count);
    }

    /// RefT is the type of the reference vector, CalcT is the type used for calculation (can be downcasted from RefT if p is low enough)
    template <typename RefT, typename CalcT>
    ColumnPtr executeDistanceCalculation(
        const ColumnArray & col_y, const ColumnsWithTypeAndName & arguments, const size_t qbit_size, size_t input_rows_count) const
    {
        const size_t precision = arguments.size() - 2;
        const size_t bytes_per_fixedstring = DataTypeQBit::bitsToBytes(qbit_size);
        const size_t padded_array_size = bytes_per_fixedstring * 8;

        /// For the sake of speed, downcast the reference vector to CalcT if `precision` is low enough
        const auto & array_data = static_cast<const ColumnVector<RefT> &>(col_y.getData()).getData();
        const PaddedPODArray<CalcT> * data_ptr;
        PaddedPODArray<CalcT> array_data_downcasted;
        if constexpr (!std::is_same_v<RefT, CalcT>)
        {
            array_data_downcasted.resize(array_data.size());
            for (size_t i = 0; i < array_data.size(); ++i)
                array_data_downcasted[i] = static_cast<CalcT>(array_data[i]);
            data_ptr = &array_data_downcasted;
        }
        else
        {
            data_ptr = &array_data;
        }

        auto col_res = ColumnVector<Float64>::create(input_rows_count);
        auto & result_data = col_res->getData();

        /// Unsigned word with the same width as CalcT, used for bit-plane untransposition
        using Word = std::conditional_t<sizeof(CalcT) == 2, UInt16, std::conditional_t<sizeof(CalcT) == 4, UInt32, UInt64>>;

        /// We process 32 rows per iteration. It's a magic number, but gives a good trade-off between memory usage and performance
        constexpr size_t block_size = 32;
        std::vector<CalcT> block(block_size * padded_array_size);
        auto block_row = [&](size_t r) -> CalcT * { return block.data() + r * padded_array_size; };

        for (size_t base_row = 0; base_row < input_rows_count; base_row += block_size)
        {
            const size_t rows_in_block = std::min(block_size, input_rows_count - base_row);

            memset(block.data(), 0, rows_in_block * padded_array_size * sizeof(CalcT));

            /// Untranspose p bit planes into all rows of the block
            for (size_t bit = 0; bit < precision; ++bit)
            {
                const auto & col = assert_cast<const ColumnFixedString &>(*extractFromConst(arguments[bit].column));
                Word bit_mask = Word(1) << (sizeof(Word) * 8 - 1 - bit);

                for (size_t r = 0; r < rows_in_block; ++r)
                {
                    const UInt8 * src = reinterpret_cast<const UInt8 *>(col.getChars().data()) + (base_row + r) * bytes_per_fixedstring;

                    SerializationQBit::untransposeBitPlane(src, reinterpret_cast<Word *>(block_row(r)), padded_array_size, bit_mask);
                }
            }

            /// Calculate L2 distance. CalcT is guaranteed to be BFloat16/Float32/Float64 by the
            /// dispatch in executeImpl; l2Distance enforces this with a static_assert.
            /// NOTE(review): the reference data is always read from offset 0 of the (unwrapped)
            /// reference-vector column — correct for a constant reference vector; confirm that
            /// per-row (non-constant) reference vectors are not expected here.
            for (size_t r = 0; r < rows_in_block; ++r)
                l2Distance(block_row(r), data_ptr->data(), qbit_size, &result_data[base_row + r]);
        }

        return col_res;
    }
};

/// Factory for the Array-based implementation of L2DistanceTransposed.
/// Used by TupleOrArrayFunction; the context is accepted for interface uniformity
/// (FunctionArrayDistance::create ignores it).
FunctionPtr createFunctionArrayL2DistanceTransposed(ContextPtr context_)
{
    return FunctionArrayDistance<L2DistanceTransposed>::create(context_);
}
}
